Style Transfer


By Prof. Seungchul Lee
http://iai.postech.ac.kr/
Industrial AI Lab at POSTECH

Table of Contents

1. Style Transfer¶



In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.applications.vgg16 import VGG16
import cv2
/home/project/.pyenv/versions/3.6.3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6
  return f(*args, **kwds)
Using TensorFlow backend.

1.1. Content Image¶

In [2]:
# Working resolution shared by the content, style, and generated images.
h_image = 600    # image height in pixels
w_image = 1000   # image width in pixels
In [3]:
# Load the content image; OpenCV reads images as BGR, so convert to RGB
# for correct display with matplotlib.
img_content = cv2.imread('./image_files/postech_flag.jpg')
img_content = cv2.cvtColor(img_content, cv2.COLOR_BGR2RGB)

# Note: cv2.resize expects (width, height), not (height, width).
img_content = cv2.resize(img_content, (w_image, h_image))

plt.figure(figsize = (10,8))
plt.imshow(img_content)
plt.axis('off')
plt.show()

1.2. Style Image¶

In [4]:
# Load the style image and bring it to the same size and color order
# (RGB) as the content image so both can be fed to the same graph.
img_style = cv2.imread('./image_files/la_muse.jpg')
img_style = cv2.cvtColor(img_style, cv2.COLOR_BGR2RGB)
img_style = cv2.resize(img_style, (w_image, h_image))

plt.figure(figsize = (10,8))
plt.imshow(img_style)
plt.axis('off')
plt.show()

1.3. Pre-trained Model (VGG16)¶

In [5]:
# Pre-trained VGG16 with ImageNet weights (downloaded on first use).
# Only its conv-layer parameters are reused below; the model itself is
# never called for inference.
model = VGG16(weights = 'imagenet')

model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 224, 224, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
fc1 (Dense)                  (None, 4096)              102764544 
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312  
_________________________________________________________________
predictions (Dense)          (None, 1000)              4097000   
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________
In [6]:
vgg16_weights = model.get_weights()

# Keras returns VGG16 parameters as a flat list alternating
# [kernel_0, bias_0, kernel_1, bias_1, ...] in layer order, so the
# conv layer at position i has its kernel at index 2*i and its bias at
# index 2*i + 1.  Building both dicts from one name list removes the
# 26 hand-written, easy-to-mistype index entries.
_conv_names = ['conv1_1', 'conv1_2',
               'conv2_1', 'conv2_2',
               'conv3_1', 'conv3_2', 'conv3_3',
               'conv4_1', 'conv4_2', 'conv4_3',
               'conv5_1', 'conv5_2', 'conv5_3']

# kernel size: [kernel_height, kernel_width, input_ch, output_ch]
weights = {name: tf.constant(vgg16_weights[2*i])
           for i, name in enumerate(_conv_names)}

# bias size: [output_ch]
biases = {name: tf.constant(vgg16_weights[2*i + 1])
          for i, name in enumerate(_conv_names)}
In [7]:
# input layer: [1, image_height, image_width, channels]
# input layer: [1, image_height, image_width, channels]
# Fixed-size placeholders (batch of one) for the two reference images.
input_content = tf.placeholder(tf.float32, [1, h_image, w_image, 3])
input_style = tf.placeholder(tf.float32, [1, h_image, w_image, 3])
In [8]:
def net(x, weights, biases):
    """Run `x` through the convolutional part of VGG16 and return every
    conv activation by layer name.

    Parameters
    ----------
    x : Tensor of shape [1, H, W, 3]
        Input image batch.
    weights, biases : dict
        Pre-trained VGG16 conv kernels / biases keyed by layer name
        (e.g. 'conv3_2').

    Returns
    -------
    dict mapping each conv layer name to its post-ReLU activation.
    The final pooling output is not returned (only conv features are
    used by the style/content losses).
    """
    def _conv_relu(inputs, name):
        # conv (stride 1, SAME padding) + bias + ReLU, matching the
        # original hand-written layers exactly.
        conv = tf.nn.conv2d(inputs,
                            weights[name],
                            strides = [1, 1, 1, 1],
                            padding = 'SAME')
        return tf.nn.relu(tf.add(conv, biases[name]))

    def _max_pool(inputs):
        # 2x2 max pooling with stride 2 (halves spatial resolution).
        return tf.nn.max_pool(inputs,
                              ksize = [1, 2, 2, 1],
                              strides = [1, 2, 2, 1],
                              padding = 'VALID')

    # VGG16 layout: five blocks of (2, 2, 3, 3, 3) conv layers, each
    # block followed by one max-pool.
    blocks = [['conv1_1', 'conv1_2'],
              ['conv2_1', 'conv2_2'],
              ['conv3_1', 'conv3_2', 'conv3_3'],
              ['conv4_1', 'conv4_2', 'conv4_3'],
              ['conv5_1', 'conv5_2', 'conv5_3']]

    features = {}
    out = x
    for block in blocks:
        for name in block:
            out = _conv_relu(out, name)
            features[name] = out
        out = _max_pool(out)

    return features
In [9]:
# Style is matched at the first conv layer of every block; content is
# matched at a single deeper layer (conv4_2), following Gatys et al.
layers_style = ['conv{}_1'.format(b) for b in range(1, 6)]
layers_content = ['conv4_2']

# Adam step size for the pixel optimization — large, because the
# variable being optimized is the image itself, not network weights.
LR = 30

1.4. Image Composition (Generation) as tf.Variable¶

In [10]:
# composite image is the only variable that needs to be updated
# composite image is the only variable that needs to be updated
# (initialized with uniform noise over the 0-255 pixel range)
input_gen = tf.Variable(tf.random_uniform([1, h_image, w_image, 3], maxval = 255))

1.5. Style Loss and Content Loss¶

In [11]:
def get_gram_matrix(conv_layer):
    """Gram matrix of a conv activation, normalized by (positions * channels).

    The activation is flattened to [H*W, C]; the result is the C x C
    matrix of channel-wise inner products, which captures feature
    correlations (the "style") independent of spatial layout.
    """
    n_ch = conv_layer.get_shape().as_list()[3]
    flat = tf.reshape(conv_layer, (-1, n_ch))
    gram = tf.matmul(flat, flat, transpose_a = True)
    n_pos = flat.get_shape().as_list()[0]
    return gram / (n_pos * n_ch)

def get_loss_style(gram_matrix_gen, gram_matrix_ref):
    """Mean squared difference between two Gram matrices."""
    return tf.reduce_mean(tf.square(gram_matrix_gen - gram_matrix_ref))

def get_loss_content(gen_layer, ref_layer):
    """Mean squared difference between two feature maps."""
    return tf.reduce_mean(tf.square(gen_layer - ref_layer))
In [12]:
# Build three copies of the VGG feature graph, all sharing the same
# constant weights: one each for the style, content, and generated images.
features_style = net(input_style, weights, biases)
features_content = net(input_content, weights, biases)

features_gen = net(input_gen, weights, biases)
In [13]:
# Accumulate style loss across the selected layers: compare Gram matrices
# of the generated image's features against the style image's.
loss_style = sum(
    get_loss_style(get_gram_matrix(features_gen[layer]),
                   get_gram_matrix(features_style[layer]))
    for layer in layers_style
)
In [14]:
# Content loss: match deep features of the generated image to those of
# the content image.
loss_content = sum(
    get_loss_content(features_gen[layer], features_content[layer])
    for layer in layers_content
)
In [15]:
# Style weight: total loss = content + 0.1 * style.
g = 1/(1e1)
loss_total = loss_content + g*loss_style

# input_gen is the only tf.Variable, so Adam updates the image pixels
# directly while the VGG weights (constants) stay fixed.
optm = tf.train.AdamOptimizer(LR).minimize(loss_total)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

1.6. Composite Image¶

In [16]:
n_iter = 1000   # total optimization steps
n_prt = 100     # progress report / display interval

for itr in range(n_iter + 1):
    # One Adam step on the generated image; both reference images are
    # fed every step because their features are recomputed in the graph.
    sess.run(optm, feed_dict = {input_style: img_style[np.newaxis,:,:,:], 
                                input_content: img_content[np.newaxis,:,:,:]})

    if itr%n_prt == 0:       
        # NOTE(review): these extra sess.run calls re-evaluate parts of the
        # graph already computed by the training step; the losses could be
        # fetched together with optm in a single run.
        ls = sess.run(loss_style, feed_dict = {input_style: img_style[np.newaxis,:,:,:]})
        lc = sess.run(loss_content, feed_dict = {input_content: img_content[np.newaxis,:,:,:]})
        
        print('Iteration: {}'.format(itr))
        print('Style loss: {}'.format(g*ls))
        print('Content loss: {}\n'.format(lc))
        
        # Clip and round the current pixels into valid uint8 range for display.
        image = sess.run(input_gen)
        image = np.uint8(np.clip(np.round(image), 0, 255)).squeeze()
        plt.figure(figsize = (10,8))
        plt.imshow(image)
        plt.axis('off')
        plt.show()
Iteration: 0
Style loss: 66748.9875
Content loss: 47092.94921875

Iteration: 100
Style loss: 1448.697265625
Content loss: 5410.0263671875

Iteration: 200
Style loss: 1416.83623046875
Content loss: 5701.77490234375

Iteration: 300
Style loss: 1497.24052734375
Content loss: 5976.86767578125

Iteration: 400
Style loss: 1582.4606445312502
Content loss: 6197.85400390625

Iteration: 500
Style loss: 1525.45712890625
Content loss: 5954.1787109375

Iteration: 600
Style loss: 1546.56748046875
Content loss: 5967.0224609375

Iteration: 700
Style loss: 1524.03916015625
Content loss: 5871.39599609375

Iteration: 800
Style loss: 1568.0197265625002
Content loss: 6021.1494140625

Iteration: 900
Style loss: 1499.763671875
Content loss: 5734.31787109375

Iteration: 1000
Style loss: 1482.635546875
Content loss: 5815.00537109375

2. Style Transfer with Total Variance Loss¶

  • Sometimes, the composite images we learn have a lot of high-frequency noise, particularly bright or dark pixels.
  • One common noise reduction method is total variation denoising.
$$\sum_{i,j} \left|x_{i,j} - x_{i+1,j}\right| + \left|x_{i,j} - x_{i,j+1}\right|$$
In [17]:
def get_loss_TV(conv_layer):
    """Total-variation loss: mean absolute difference between horizontally
    and vertically adjacent values, penalizing high-frequency noise."""
    diff_w = conv_layer[:, :, 1:, :] - conv_layer[:, :, :-1, :]
    diff_h = conv_layer[:, 1:, :, :] - conv_layer[:, :-1, :, :]
    return tf.reduce_mean(tf.abs(diff_w)) + tf.reduce_mean(tf.abs(diff_h))
In [18]:
# TV loss is applied directly to the generated image pixels (not to a feature map).
loss_TV = get_loss_TV(input_gen)
In [19]:
# NOTE(review): unlike Section 1, loss_style enters UNWEIGHTED here (no g
# factor), yet the progress printout below still reports g*ls — confirm
# whether the style term was meant to carry the g weight.
loss_total = loss_content + loss_style + 100*loss_TV

optm = tf.train.AdamOptimizer(LR).minimize(loss_total)

# Fresh session: re-initializes input_gen to new random noise.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
In [20]:
n_iter = 500    # fewer steps than Section 1; TV term speeds convergence to a clean image
n_prt = 100     # progress report / display interval

for itr in range(n_iter + 1):
    # One Adam step on the generated image under content + style + TV loss.
    sess.run(optm, feed_dict = {input_style : img_style[np.newaxis,:,:,:],
                                input_content : img_content[np.newaxis,:,:,:]})

    if itr%n_prt == 0:       
        ls = sess.run(loss_style, feed_dict = {input_style : img_style[np.newaxis,:,:,:]})
        lc = sess.run(loss_content, feed_dict = {input_content : img_content[np.newaxis,:,:,:]})
        ltv = sess.run(loss_TV)
        
        print('Iteration: {}'.format(itr))
        # NOTE(review): prints g*ls although the optimized loss uses the
        # unweighted loss_style here — the reported number understates the
        # actual style term by a factor of 10.
        print('Style loss: {}'.format(g*ls))
        print('Content loss: {}'.format(lc))
        print('TV loss: {}\n'.format(ltv))
        
        # Clip and round the current pixels into valid uint8 range for display.
        image = sess.run(input_gen)
        image = np.uint8(np.clip(np.round(image), 0, 255)).squeeze()
        plt.figure(figsize = (10,8))
        plt.imshow(image)
        plt.axis('off')
        plt.show()
Iteration: 0
Style loss: 63868.86875
Content loss: 54238.24609375
TV loss: 164.00770568847656

Iteration: 100
Style loss: 624.92724609375
Content loss: 16034.91015625
TV loss: 116.65840148925781

Iteration: 200
Style loss: 356.1904296875
Content loss: 10750.283203125
TV loss: 68.02838134765625

Iteration: 300
Style loss: 287.88730468750003
Content loss: 8574.265625
TV loss: 36.62832260131836

Iteration: 400
Style loss: 260.554248046875
Content loss: 7473.9150390625
TV loss: 23.991535186767578

Iteration: 500
Style loss: 248.763525390625
Content loss: 6812.0283203125
TV loss: 20.514698028564453

In [21]:
%%javascript
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js')